{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# DAILY GLOBAL HISTORICAL CLIMATOLOGY NETWORK (GHCN-DAILY) \n",
    "\n",
    "Dense and regression tasks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import shutil\n",
    "import sys\n",
    "sys.path.append('../..')\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"1\"  # change to chosen GPU to use, nothing if work on CPU\n",
    "\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from tqdm import tqdm\n",
    "from mpl_toolkits.mplot3d import Axes3D\n",
    "import cartopy.crs as ccrs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from deepsphere import models, experiment_helper, plot, utils"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from GHCN_preprocessing import get_data, get_stations, sphereGraph, clean_nodes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from scipy import sparse\n",
    "\n",
    "def regression_tikhonov(G, y, M, tau=0):\n",
    "    r\"\"\"Solve a regression problem on graph via Tikhonov minimization.\n",
    "\n",
    "    The function solves\n",
    "\n",
    "    .. math:: \\operatorname*{arg min}_x \\| M x - y \\|_2^2 + \\tau \\ x^T L x\n",
    "\n",
    "    if :math:`\\tau > 0`, and\n",
    "\n",
    "    .. math:: \\operatorname*{arg min}_x x^T L x \\ \\text{ s. t. } \\ y = M x\n",
    "\n",
    "    otherwise.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    G : :class:`pygsp.graphs.Graph`\n",
    "        Graph whose Laplacian ``G.L`` (scipy sparse matrix or dense array) is used.\n",
    "    y : array, length G.n_vertices\n",
    "        Measurements. Entries where ``M`` is False are ignored and may be NaN.\n",
    "        The array is never modified in place.\n",
    "    M : array of boolean, length G.n_vertices\n",
    "        Masking vector, True where a measurement is available.\n",
    "    tau : float\n",
    "        Regularization parameter.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    x : array, length G.n_vertices\n",
    "        Recovered values :math:`x`.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If ``M`` does not hold exactly ``G.n_vertices`` entries.\n",
    "\n",
    "    Examples\n",
    "    --------\n",
    "    >>> from pygsp import graphs, filters, learning\n",
    "    >>> import matplotlib.pyplot as plt\n",
    "    >>>\n",
    "    >>> G = graphs.Sensor(N=100, seed=42)\n",
    "    >>> G.estimate_lmax()\n",
    "\n",
    "    Create a smooth ground truth signal:\n",
    "\n",
    "    >>> filt = lambda x: 1 / (1 + 10*x)\n",
    "    >>> filt = filters.Filter(G, filt)\n",
    "    >>> rs = np.random.RandomState(42)\n",
    "    >>> signal = filt.analyze(rs.normal(size=G.n_vertices))\n",
    "\n",
    "    Construct a measurement signal from a binary mask:\n",
    "\n",
    "    >>> mask = rs.uniform(0, 1, G.n_vertices) > 0.5\n",
    "    >>> measures = signal.copy()\n",
    "    >>> measures[~mask] = np.nan\n",
    "\n",
    "    Solve the regression problem by reconstructing the signal:\n",
    "\n",
    "    >>> recovery = learning.regression_tikhonov(G, measures, mask, tau=0)\n",
    "\n",
    "    Plot the results:\n",
    "\n",
    "    >>> fig, (ax1, ax2, ax3) = plt.subplots(1, 3, sharey=True, figsize=(10, 3))\n",
    "    >>> limits = [signal.min(), signal.max()]\n",
    "    >>> _ = G.plot_signal(signal, ax=ax1, limits=limits, title='Ground truth')\n",
    "    >>> _ = G.plot_signal(measures, ax=ax2, limits=limits, title='Measures')\n",
    "    >>> _ = G.plot_signal(recovery, ax=ax3, limits=limits, title='Recovery')\n",
    "    >>> _ = fig.tight_layout()\n",
    "    \"\"\"\n",
    "    # Local imports: `from scipy import sparse` alone does not guarantee that the\n",
    "    # `sparse.linalg` submodule has been loaded.\n",
    "    import warnings\n",
    "    from scipy.sparse import linalg as sparse_linalg\n",
    "\n",
    "    y = np.asarray(y)\n",
    "    M = np.asarray(M, dtype=bool).ravel()\n",
    "    if M.size != G.n_vertices:\n",
    "        raise ValueError(\"M should be of size [G.n_vertices,]\")\n",
    "\n",
    "    if tau > 0:\n",
    "\n",
    "        # Zero the unknown entries on a copy so that the caller's array keeps its values.\n",
    "        rhs = np.array(y)\n",
    "        rhs[~M] = 0\n",
    "\n",
    "        if sparse.issparse(G.L):\n",
    "\n",
    "            def Op(x):\n",
    "                return (M * x.T).T + tau * (G.L.dot(x))\n",
    "\n",
    "            LinearOp = sparse_linalg.LinearOperator((G.n_vertices, G.n_vertices), matvec=Op)\n",
    "\n",
    "            def solve_column(b):\n",
    "                # `cg` returns the solution and a status flag (0 means converged).\n",
    "                sol, info = sparse_linalg.cg(LinearOp, b)\n",
    "                if info != 0:\n",
    "                    warnings.warn(\"conjugate gradient did not converge (info={})\".format(info), RuntimeWarning)\n",
    "                return sol\n",
    "\n",
    "            if rhs.ndim > 1:\n",
    "                # One independent solve per column of measurements.\n",
    "                sol = np.empty(shape=rhs.shape)\n",
    "                for i in range(rhs.shape[1]):\n",
    "                    sol[:, i] = solve_column(rhs[:, i])\n",
    "                return sol\n",
    "\n",
    "            return solve_column(rhs)\n",
    "\n",
    "        # Creating this matrix may be problematic in term of memory.\n",
    "        # Consider using an operator instead...\n",
    "        # `np.asarray` keeps `LinearOp` defined for any dense Laplacian type.\n",
    "        LinearOp = np.diag(M * 1) + tau * np.asarray(G.L)\n",
    "        return np.linalg.solve(LinearOp, rhs)\n",
    "\n",
    "    indl = M\n",
    "    indu = ~M\n",
    "\n",
    "    # Nothing to recover when every vertex is measured.\n",
    "    if not indu.any():\n",
    "        return y.copy()\n",
    "\n",
    "    # Blocks of the Laplacian: unknown-unknown and (negated) unknown-known.\n",
    "    Luu = G.L[indu, :][:, indu]\n",
    "    Wul = - G.L[indu, :][:, indl]\n",
    "\n",
    "    if sparse.issparse(G.L):\n",
    "        sol_part = sparse_linalg.spsolve(Luu.tocsc(), Wul.dot(y[indl]))\n",
    "    else:\n",
    "        sol_part = np.linalg.solve(Luu, np.matmul(Wul, y[indl]))\n",
    "\n",
    "    sol = y.copy()\n",
    "    sol[indu] = sol_part\n",
    "\n",
    "    return sol"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "datapath = \"/mnt/nas/LTS2/datasets/ghcn-daily/processed/\"\n",
    "rawpath = \"/mnt/nas/LTS2/datasets/ghcn-daily/raw/\"\n",
    "newdatapath = \"./data/ghcn-daily/processed/\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "years = np.arange(2010,2015)\n",
    "years2 = np.arange(2010,2018)\n",
    "\n",
    "feature_names = ['PRCP', 'TMIN', 'TMAX', 'SNOW', 'SNWD', 'WT']\n",
    "# max stations, 50'469\n",
    "# in 2010, in nbr stations [36k, 16k, 16k, 6k, 22k, (20k, 0, 0, 28, 1k, 6), 5k]\n",
    "# in 2011, in nbr stations [36k, 16k, 16k, 6k, 22k, (18k, 0, 0, 38, 1k, 22), 5k]\n",
    "# in 2014, in nbr stations [39k, 15k, 15k, 6k, 24k, (19k, 0, 0, 40, 1k, 49), 4k]\n",
    "n_features = len(feature_names)\n",
    "n_years  = len(years)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Fetch data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_stations, ghcn_to_local, lat, lon, alt, _ = get_stations(datapath, years)\n",
    "full_data, n_days, valid_days = get_data(newdatapath, years, feature_names, ghcn_to_local)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "assert n_stations == full_data.shape[0]\n",
    "\n",
    "print(f'n_stations: {n_stations}, n_days: {n_days}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "neighbour = 10"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Preprocess data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Temperature MAX from Temperature MIN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "EXP = 'temp'\n",
    "datas_temp, keep_temp, gTemp = clean_nodes(full_data, [1,3], lon, lat, figs=True, rad=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "min(gTemp.dw)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.hist(gTemp.dw, )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# `np.bool` was removed in NumPy 1.24, so the builtin `bool` dtype is used instead.\n",
    "# `valid_days` is cut into one equal-sized chunk per year; a year whose chunk\n",
    "# contains 366 valid days is flagged as a leap year.\n",
    "leap_years = np.array([in_year.sum() == 366 for in_year in np.split(valid_days, len(years))],\n",
    "                      dtype=bool)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "n_years_span = years[-1] - years[0] + 1\n",
    "# Month index of every valid day: each year is laid out as 12 blocks of 31 slots\n",
    "# and `valid_days` indexes into that layout.\n",
    "w_months = np.tile(np.repeat(np.arange(12), 31), n_years_span)[valid_days]\n",
    "# Day-of-year index, built year by year. Inserting the extra leap day at a fixed\n",
    "# multiple of 365 is only right for the first leap year: every earlier insertion\n",
    "# shifts the following years by one position.\n",
    "w_days = np.concatenate([np.arange(366 if leap else 365) for leap in leap_years])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "w_days_sin = np.sin(w_days/367*np.pi)\n",
    "w_days_cos = -np.cos(w_days/367*np.pi*2)/2+0.5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from GHCN_preprocessing import dataset_temp\n",
    "training, validation = dataset_temp(datas_temp, lon[keep_temp], lat[keep_temp], alt[keep_temp], w_days, add_feat=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "###### Super set"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plain assignment keeps the cell idempotent (re-running no longer appends the suffix twice).\n",
    "EXP = 'temp_superset'\n",
    "# The array is named `dataset_super` because that is the name every later cell reads.\n",
    "dataset_super, keepSuper, gTempSuper = clean_nodes(full_data, [1,3], lon, lat, superset=True, figs=False, rad=False, epsilon=False)\n",
    "# True where a measurement is present, False where the value is NaN.\n",
    "mask = ~np.isnan(dataset_super)\n",
    "gTempSuper.compute_laplacian(\"combinatorial\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "min(gTempSuper.dw)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.hist(gTempSuper.dw)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for day in range(dataset_super.shape[0]):\n",
    "    for feat in range(dataset_super.shape[-1]):\n",
    "        dataset_super[day,:,feat] = regression_tikhonov(gTempSuper, dataset_super[day,:,feat], mask[day,:,feat])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "training, validation = dataset_temp(dataset_super, lon[keepSuper], lat[keepSuper], alt[keepSuper], w_days, add_feat=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.scatter(dataset_super[:,:,0], dataset_super[:,:,1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "np.where(dataset_super[:,:,0]> dataset_super[:,:,1])[0]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Precipitation from temperatures (MIN, MAX)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "EXP = 'prec'\n",
    "datas_prec, keep_prec, gPrec = clean_nodes(full_data, [0,3], lon, lat, figs=True, rad=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from GHCN_preprocessing import dataset_prec\n",
    "training, validation = dataset_prec(datas_prec, lon[keep_prec], lat[keep_prec], alt[keep_prec], w_days, add_feat=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "x_train, _ = training.get_all_data()\n",
    "time = np.empty_like(x_train[:,:,0])\n",
    "time[:,:] = np.arange(time.shape[0])[:,np.newaxis]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Find future temperature"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "EXP = 'future'\n",
    "datas_temp_reg, keep_reg, gReg = clean_nodes(full_data, [1,3], lon, lat, figs=True, rad=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from GHCN_preprocessing import dataset_reg\n",
    "training, validation = dataset_reg(datas_temp_reg, lon[keep_reg], lat[keep_reg], alt[keep_reg], w_days_sin, add_feat=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "placereg = np.empty_like(datas_temp_reg[:,:,0])\n",
    "placereg[:,:] = np.arange(placereg.shape[1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## is it possible to have a difference of 40 degrees between 1 day?  (must fix to +- 30 degrees)\n",
    "plt.scatter(x_train[:,:,days_pred-1], labels_train)\n",
    "plt.plot(range(-50, 45), range(-50, 45),'r')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### dense-classification, segmentation weather types"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "EXP = 'seg'\n",
    "datas_seg, keep_seg, gSeg = clean_nodes(full_data, [-1,None], lon, lat, figs=False, superset=True, rad=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "filt_data = []\n",
    "for i in range(1,3):\n",
    "    filt_data.append(datas_seg[:,:,i][~np.isnan(datas_seg[:,:,i])].flatten())\n",
    "plt.boxplot(filt_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dict_label = {1: 'fog', 2: 'heavy fog', 3: 'thunder', 4: 'ice pellets', 5: 'hail', 6: 'rime ice', 7: 'volcanic ash',\n",
    "              8: 'smoke', 9: 'blowing snow', 10: 'tornado', 11: 'high winds', 12: 'blowing spray', 13: 'mist',\n",
    "              14: 'drizzle', 15: 'freezing drizzle', 16: 'rain', 17: 'freezing rain', 18: 'snow', 19: 'unkown prec',\n",
    "              21: 'ground fog', 22: 'ice fog'}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fog = datas_seg[np.where(datas_seg[:,:,-1]==4)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "filt_fog = []\n",
    "for i in range(1,3):\n",
    "    filt_fog.append(fog[:,i][~np.isnan(fog[:,i])])\n",
    "plt.boxplot(filt_fog)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Snow prediction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "EXP = 'snow'\n",
    "datas_snow, keep_snow, gSnow = clean_nodes(full_data, [0,5], lon, lat, figs=True, rad=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from GHCN_preprocessing import dataset_snow\n",
    "# Use the snow-specific helper imported on the line above.\n",
    "# NOTE(review): assumes `dataset_snow` accepts the same arguments as the other `dataset_*` helpers - confirm.\n",
    "training, validation = dataset_snow(datas_snow, lon[keep_snow], lat[keep_snow], alt[keep_snow], w_days_sin, add_feat=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Global regression task"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "EXP = 'global_reg'\n",
    "plt.plot(w_days)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.plot(w_days_sin)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.plot(w_days_cos)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from GHCN_preprocessing import dataset_global\n",
    "training, validation = dataset_global(datas_prec, lon[keep_prec], lat[keep_prec], alt[keep_prec], w_days_cos, add_feat=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from GHCN_train import hyperparameters_dense, hyperparameters_global"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "nfeat = training.get_all_data()[0].shape[-1]\n",
    "if 'global' in EXP:\n",
    "    params = hyperparameters_global(gPrec, EXP, neighbour, nfeat, training.N)\n",
    "else:\n",
    "    params = hyperparameters_dense(gReg, EXP, neighbour, nfeat, training.N)\n",
    "if 'super' in EXP:\n",
    "    params['mask'] = [mask[:limit,:,1], mask[limit:,:,1]]\n",
    "model = models.cgcnn(**params)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Remove the outputs of a previous run of this experiment.\n",
    "# NOTE(review): assumes the hyperparameter dict stores the run directory under 'dir_name' - confirm;\n",
    "# falls back to the experiment tag `EXP` when that key is absent.\n",
    "exp_dir = params.get('dir_name', EXP)\n",
    "for folder in ('summaries', 'checkpoints'):\n",
    "    shutil.rmtree('{}/{}/'.format(folder, exp_dir), ignore_errors=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "accuracy_validation, loss_validation, loss_training, t_step, t_batch = model.fit(training, validation)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot.plot_loss(loss_training, loss_validation, t_step, params['eval_frequency'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "x_val, labels_val = validation.get_all_data()\n",
    "res = model.predict(x_val)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.metrics import r2_score, explained_variance_score, mean_absolute_error, mean_squared_error\n",
    "def mre(labels, predictions):\n",
    "    \"\"\"Return the mean relative error of `predictions` against `labels`, in percent.\n",
    "\n",
    "    The denominator is `labels` clipped from below at 1, which avoids dividing by\n",
    "    values smaller than 1.\n",
    "    \"\"\"\n",
    "    denominator = np.clip(labels, 1, None)\n",
    "    relative_error = np.abs((labels - predictions) / denominator)\n",
    "    return np.mean(relative_error) * 100"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "res = model.predict(np.atleast_3d(x_val))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if 'global' not in EXP:\n",
    "    mse_ = (mean_squared_error(labels_val[:-1,:], res[1:,:]))\n",
    "    mae_ = (mean_absolute_error(labels_val[:-1,:], res[1:,:]))\n",
    "    mre_ = (mre(labels_val[:-1,:], res[1:,:]))\n",
    "    r2_ = (r2_score(labels_val[:-1,:], res[1:,:]))\n",
    "    expvar_ = (explained_variance_score(labels_val[:-1,:], res[1:,:]))\n",
    "    print(\"MSE={:.2f}, MAE={:.2f}, MRE={:.2f}, R2={:.3f}, Expvar={:.4f}\".format(mse_, mae_, mre_, r2_, expvar_))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "if 'future' in EXP:\n",
    "    predictions = []\n",
    "    for i in range(days_pred):\n",
    "        predictions.append(dataset_temp_reg[i,:,0])\n",
    "    for i in tqdm(range(len(dataset_temp_reg)-2*days_pred)):\n",
    "        x_pred = np.asarray(predictions[-days_pred:]).T\n",
    "        x_pred = np.hstack([x_pred, \n",
    "    #                      np.broadcast_to(w_months[np.newaxis,i:i+days_pred], x_pred.shape),\n",
    "                           coords_v,\n",
    "                           alt_v[:,np.newaxis],\n",
    "                           np.repeat(w_days_sin[i], x_pred.shape[0], axis=0)[:,np.newaxis]])\n",
    "                           #np.broadcast_to(w_days[np.newaxis, i:i+days_pred], x_pred.shape)])\n",
    "        x_pred = np.repeat(x_pred[np.newaxis,:,:], 64, axis=0)\n",
    "        res = model.predict(x_pred)\n",
    "        predictions.append(res[0,:])\n",
    "        \n",
    "    plt.scatter(timereg[days_pred:,:], np.asarray(predictions))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "res = model.predict(np.atleast_3d(x_val))\n",
    "plt.scatter(x_val[:,:,0], res)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.metrics import mean_squared_error as mse\n",
    "from sklearn.metrics import mean_absolute_error as mae"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mse_loc = []\n",
    "mae_loc = []\n",
    "for i in range(res.shape[1]):\n",
    "#     mse_loc.append(mse(dataset_temp[:,i,1], res[:,i]))\n",
    "#     mae_loc.append(mae(dataset_temp[:,i,1], res[:,i]))\n",
    "#     mse_loc.append(mse(labels_train[:,i]*mask[:limit,i,1], res[:,i]*mask[:limit,i,1]))\n",
    "#     mae_loc.append(mae(labels_train[:,i]*mask[:limit,i,1], res[:,i]*mask[:limit,i,1]))\n",
    "    mse_loc.append(mse(labels_val[:,i], res[:,i]))\n",
    "    mae_loc.append(mae(labels_val[:,i], res[:,i]))\n",
    "plt.plot(mse_loc)\n",
    "plt.figure()\n",
    "plt.plot(mae_loc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mse_tab = []\n",
    "mae_tab = []\n",
    "for i in range(res.shape[0]-1):\n",
    "#     mse_tab.append(mse(dataset_temp[i,:,1], res[i,:]))\n",
    "#     mse_tab.append(mse(labels_train[i,:]*mask[i,:,1], res[i,:]*mask[i,:,1]))\n",
    "    mse_tab.append(mse(labels_val[i,:], res[i,:]))\n",
    "    mae_tab.append(mae(labels_val[i,:], res[i,:]))\n",
    "plt.plot(mse_tab)\n",
    "plt.figure()\n",
    "plt.plot(mae_tab)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "arg = np.argmax(mse_tab)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mse_loc = []\n",
    "mae_loc = []\n",
    "index = []\n",
    "for i in range(res.shape[1]):\n",
    "    mse_loc.append(mse([labels_val[arg,i]], [res[arg,i]]))\n",
    "    mae_loc.append(mae([labels_val[arg,i]], [res[arg,i]]))\n",
    "    if mae_loc[-1]>7:\n",
    "        index.append(i)\n",
    "plt.plot(mse_loc)\n",
    "plt.figure()\n",
    "plt.plot(mae_loc)\n",
    "index = np.asarray(index)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig = plt.figure(figsize=(30, 50))\n",
    "ax = fig.add_subplot(3, 2, 1, projection=ccrs.Orthographic(-50, 90))\n",
    "ax.set_global()\n",
    "ax.coastlines(linewidth=2)\n",
    "ax.set_title('labels day D')\n",
    "\n",
    "zmin, zmax = 0, 40\n",
    "\n",
    "sc = ax.scatter(lon[keep], lat[keep], s=10,\n",
    "                c=np.clip(labels_val[arg, :], zmin, zmax), cmap=plt.get_cmap('RdYlBu_r'),\n",
    "#                 c=np.clip(labels_train[arg, :]*masknan[arg, :, 1], zmin, zmax), cmap=plt.get_cmap('RdYlBu_r'),\n",
    "                vmin=zmin, vmax=zmax, alpha=1, transform=ccrs.PlateCarree())\n",
    "\n",
    "ax = fig.add_subplot(3, 2, 2, projection=ccrs.Orthographic(-50, 90))\n",
    "ax.set_global()\n",
    "ax.coastlines(linewidth=2)\n",
    "ax.set_title('prediction day D')\n",
    "\n",
    "\n",
    "sc = ax.scatter(lon[keep], lat[keep], s=10,\n",
    "                c=np.clip(res[arg, :], zmin, zmax), cmap=plt.get_cmap('RdYlBu_r'),\n",
    "#                 c=np.clip(res[arg, :]*masknan[arg, :, 1], zmin, zmax), cmap=plt.get_cmap('RdYlBu_r'),\n",
    "                vmin=zmin, vmax=zmax, alpha=1, transform=ccrs.PlateCarree())\n",
    "\n",
    "ax = fig.add_subplot(3, 2, 4, projection=ccrs.Orthographic(-50, 90))\n",
    "ax.set_global()\n",
    "ax.coastlines(linewidth=2)\n",
    "ax.set_title('prediction day D+1')\n",
    "\n",
    "sc = ax.scatter(lon[keep], lat[keep], s=10,\n",
    "                c=np.clip(res[arg+1, :], zmin, zmax), cmap=plt.get_cmap('RdYlBu_r'),\n",
    "#                 c=np.clip(res[arg+1, :]*masknan[arg+1, :, 1], zmin, zmax), cmap=plt.get_cmap('RdYlBu_r'),\n",
    "                vmin=zmin, vmax=zmax, alpha=1, transform=ccrs.PlateCarree())\n",
    "\n",
    "# ax = fig.add_subplot(3, 2, 5, projection=ccrs.Orthographic(-50, 90))\n",
    "# ax.set_global()\n",
    "# ax.coastlines(linewidth=2)\n",
    "# ax.set_title('given D day D-1')\n",
    "\n",
    "\n",
    "# sc = ax.scatter(lon[keepToo], lat[keepToo], s=10,\n",
    "#                 c=np.clip(x_train[arg, :, days_pred-1], zmin, zmax), cmap=plt.get_cmap('RdYlBu_r'),\n",
    "#                 vmin=zmin, vmax=zmax, alpha=1, transform=ccrs.PlateCarree())\n",
    "\n",
    "ax = fig.add_subplot(3, 2, 3, projection=ccrs.Orthographic(-50, 90))\n",
    "ax.set_global()\n",
    "ax.coastlines(linewidth=2)\n",
    "ax.set_title('labels day D+1')\n",
    "\n",
    "\n",
    "sc = ax.scatter(lon[keep], lat[keep], s=10,\n",
    "                c=np.clip(labels_val[arg+1, :], zmin, zmax), cmap=plt.get_cmap('RdYlBu_r'),\n",
    "                vmin=zmin, vmax=zmax, alpha=1, transform=ccrs.PlateCarree())\n",
    "\n",
    "# ax = fig.add_subplot(3, 2, 6, projection=ccrs.Orthographic(-50, 90))\n",
    "# ax.set_global()\n",
    "# ax.coastlines(linewidth=2)\n",
    "# ax.set_title('given D+1 day D-1')\n",
    "\n",
    "# sc = ax.scatter(lon[keepToo], lat[keepToo], s=10,\n",
    "#                 c=np.clip(x_train[arg+1, :, days_pred-2], zmin, zmax), cmap=plt.get_cmap('RdYlBu_r'),\n",
    "#                 vmin=zmin, vmax=zmax, alpha=1, transform=ccrs.PlateCarree())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig = plt.figure(figsize=(30, 15))\n",
    "ax = fig.add_subplot(121, projection=ccrs.Orthographic(-50, 90))\n",
    "ax.set_global()\n",
    "ax.coastlines(linewidth=2)\n",
    "ax.set_title('big deviation')\n",
    "\n",
    "\n",
    "sc = ax.scatter(lon[keep][index], lat[keep][index], s=10,\n",
    "                c=np.clip(labels_train[arg, :][index], zmin, zmax), cmap=plt.get_cmap('RdYlBu_r'),\n",
    "                vmin=zmin, vmax=zmax, alpha=1, transform=ccrs.PlateCarree())\n",
    "\n",
    "ax = fig.add_subplot(122, projection=ccrs.Orthographic(-50, 90))\n",
    "ax.set_global()\n",
    "ax.coastlines(linewidth=2)\n",
    "ax.set_title('big deviation')\n",
    "\n",
    "\n",
    "sc = ax.scatter(lon[keep][index], lat[keep][index], s=10,\n",
    "                c=np.clip(res[arg, :][index], zmin, zmax), cmap=plt.get_cmap('RdYlBu_r'),\n",
    "                vmin=zmin, vmax=zmax, alpha=1, transform=ccrs.PlateCarree())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Figures"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pathfig = './figures/'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "g_full = sphereGraph(lon, lat, 2, rad=False, epsilon=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig = plt.figure(figsize=(25,25))\n",
    "ax = fig.add_subplot(111, projection='3d')\n",
    "g_full.plot(vertex_size=10, edges=False, ax=ax)\n",
    "ax.set_title('world weather stations', fontsize=20)\n",
    "ax.view_init(azim=90)\n",
    "for axis in (ax.xaxis, ax.yaxis, ax.zaxis):\n",
    "    # make the panes transparent\n",
    "    axis.set_pane_color((1.0, 1.0, 1.0, 0.0))\n",
    "    # make the grid lines transparent\n",
    "    axis._axinfo[\"grid\"]['color'] = (1,1,1,0)\n",
    "ax.axis('off')\n",
    "# `bbox_inches` is the savefig keyword for tight cropping.\n",
    "plt.savefig(pathfig+\"world_graph_USA.svg\", bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "g_full.compute_laplacian(\"normalized\")\n",
    "g_full.compute_fourier_basis(recompute=True, n_eigenvectors=500)\n",
    "g_full.set_coordinates(g_full.U[:,1:4])\n",
    "g_full.plot(vertex_size=10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.plot(g_full.e[:16], 'o')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Kept stations with positive latitude (tuple of index arrays, as returned by np.where).\n",
    "us_lat = np.where(lat[keep_reg]>0)\n",
    "# Among those, stations with longitude below -30, as a flat integer index array so that\n",
    "# `arr[us_lat][us]` is plain 1-D fancy indexing.\n",
    "us = np.where(lon[keep_reg][us_lat] < -30)[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "fig = plt.figure(figsize=(20, 10))\n",
    "ax = fig.add_subplot(1, 1, 1, projection=ccrs.PlateCarree())\n",
    "\n",
    "ax.set_global()\n",
    "# ax.stock_img()\n",
    "ax.coastlines()\n",
    "\n",
    "# `us_lat` and `us` are computed from `keep_reg`, so the same station selection is indexed here.\n",
    "# The 'or' format string already requests red circle markers.\n",
    "tmp = plt.plot(lon[keep_reg][us_lat][us], lat[keep_reg][us_lat][us], 'or', markersize=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "neighbour = 8\n",
    "# `us_lat` and `us` index into the `keep_reg` selection, so the graph is built from that selection.\n",
    "gReg = sphereGraph(lon[keep_reg][us_lat][us], lat[keep_reg][us_lat][us], neighbour, rad=False, epsilon=False)\n",
    "gReg.plot(vertex_size=10, edges=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig = plt.figure(figsize=(10,10))\n",
    "ax = fig.add_subplot(111, projection='3d')\n",
    "gReg.plot(vertex_size=10, edges=True, ax=ax)\n",
    "ax.set_title('USA temperature stations', fontsize=20)\n",
    "ax.view_init(elev=0, azim=70)\n",
    "for axis in (ax.xaxis, ax.yaxis, ax.zaxis):\n",
    "    # make the panes transparent\n",
    "    axis.set_pane_color((1.0, 1.0, 1.0, 0.0))\n",
    "    # make the grid lines transparent\n",
    "    axis._axinfo[\"grid\"]['color'] = (1,1,1,0)\n",
    "ax.axis('off')\n",
    "\n",
    "# translucent unit sphere\n",
    "u, v = np.mgrid[0:2*np.pi:80j, 0:np.pi:40j]\n",
    "x = np.cos(u)*np.sin(v)\n",
    "y = np.sin(u)*np.sin(v)\n",
    "z = np.cos(v)\n",
    "ax.plot_surface(x, y, z, color=\"r\", alpha=0.1)\n",
    "\n",
    "# `bbox_inches` is the savefig keyword for tight cropping.\n",
    "plt.savefig(pathfig+\"USA_temp.svg\", bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# `datas_temp_reg` is the array returned by `clean_nodes` in the future-temperature section.\n",
    "np.unravel_index(np.argmax(datas_temp_reg), datas_temp_reg.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# `datas_temp_reg` is the array returned by `clean_nodes` in the future-temperature section.\n",
    "timereg = np.empty_like(datas_temp_reg[:,:,0])\n",
    "# Every column of `timereg` holds the same day index 0..n_days-1.\n",
    "timereg[:,:] = np.arange(timereg.shape[0])[:,np.newaxis]\n",
    "for station in (50, 672, 10):\n",
    "    plt.scatter(timereg[:,0], datas_temp_reg[:,station,0], s=10)\n",
    "plt.xlabel('day', fontsize=13)\n",
    "plt.ylabel('Temperature [°C]', fontsize=13)\n",
    "# `bbox_inches` is the savefig keyword for tight cropping.\n",
    "plt.savefig(pathfig+'evolution_temp.png', bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.plot(w_days_cos*100)\n",
    "# `datas_prec` is the data array; `dataset_prec` is the function imported from GHCN_preprocessing.\n",
    "plt.plot(datas_prec[:,2,0], 'o', markersize=2)\n",
    "plt.plot(datas_prec[:,20,0], 'o', markersize=2)\n",
    "plt.ylabel('precipitations [mm]', fontsize=15)\n",
    "plt.xlabel('time [days]', fontsize=15)\n",
    "# `bbox_inches` is the savefig keyword for tight cropping.\n",
    "plt.savefig(pathfig+'prec_time.png', bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig = plt.figure(figsize=(20, 20))\n",
    "ax = fig.add_subplot(111, projection=ccrs.Orthographic(-50, 90))\n",
    "ax.set_global()\n",
    "ax.coastlines(linewidth=2)\n",
    "ax.set_title('', fontsize=20)\n",
    "\n",
    "zmin = -20\n",
    "zmax = 30\n",
    "\n",
    "# `keep_reg` and `datas_temp_reg` come from the same `clean_nodes` call.\n",
    "sc = ax.scatter(lon[keep_reg], lat[keep_reg], s=20,\n",
    "                c=np.clip(datas_temp_reg[50, :, 1], zmin, zmax), cmap=plt.get_cmap('RdYlBu_r'),\n",
    "                vmin=zmin, vmax=zmax, alpha=1, transform=ccrs.PlateCarree())\n",
    "\n",
    "cbar = plt.colorbar(sc, ax=ax)\n",
    "cbar.ax.tick_params(labelsize=30)\n",
    "\n",
    "# `bbox_inches` is the savefig keyword for tight cropping.\n",
    "plt.savefig(pathfig+\"temp_min.png\", bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if 'global' in EXP:\n",
    "    plt.plot(labels_val, label='labels')\n",
    "    plt.plot(res, label='predictions')\n",
    "    plt.xlabel('day', fontsize=15)\n",
    "    plt.ylabel('predicted day', fontsize=15)\n",
    "    plt.xticks(fontsize=15)\n",
    "    plt.yticks(fontsize=15)\n",
    "    plt.legend(fontsize=13)\n",
    "    # plt.ylim(top=1.1)\n",
    "    # `bbox_inches` is the savefig keyword for tight cropping.\n",
    "    plt.savefig(pathfig+'global_linear_results.png', bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if 'future' in EXP:\n",
    "    predicted = np.asarray(predictions)\n",
    "    plt.scatter(timereg[days_pred:,50], predicted[:,50], s=10)\n",
    "    plt.scatter(timereg[days_pred:,0], predicted[:,672], s=10)\n",
    "    plt.scatter(timereg[days_pred:,10], predicted[:,10], s=10)\n",
    "    # `bbox_inches` is the savefig keyword for tight cropping.\n",
    "    plt.savefig(pathfig+'prediction_temp.png', bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if 'prec' in EXP:\n",
    "    # `res` is predicted from `x_val`, so it is compared against `labels_val`.\n",
    "    plt.scatter(labels_val, res, s=2)\n",
    "    # identity line: a perfect prediction falls on it\n",
    "    plt.plot([0, 100],[0, 100], 'r')\n",
    "    plt.xlabel('labels', fontsize=20)\n",
    "    plt.ylabel('predictions', fontsize=20)\n",
    "    plt.xticks(fontsize=15)\n",
    "    plt.yticks(fontsize=15)\n",
    "    # `bbox_inches` is the savefig keyword for tight cropping.\n",
    "    plt.savefig(pathfig+'prec_reg_results.png', bbox_inches='tight')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
